#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
# sys.setdefaultencoding() is removed from the sys module at interpreter
# start-up; reload(sys) restores it (Python 2 only hack)
reload(sys)
sys.setdefaultencoding('UTF8')

import jellyfish
import fileinput
import functools
import multiprocessing
import re
import time
import os
import cPickle
import Queue

from whoosh.index import open_dir
from whoosh.query import spans
from whoosh import query
from nltk import word_tokenize, bigrams, ngrams
from nltk.corpus import stopwords
from collections import defaultdict

# sys.path.append(sys.path.append(root-dir))
from Sentence import Sentence

NUM_CPUS = 1

# relational words used in calculating the sets C and D with the proximity PMI

founded_unigrams = ['founder', 'co-founder', 'cofounder', 'co-founded',
                    'cofounded', 'founded', 'founders']

founded_bigrams = ['started by', 'established by']
founded_trigrams = []

acquired_unigrams = ['owns', 'acquired', 'acquires', 'acquire', 'acquiring', 'acquisition',
                     'bought', 'buy', 'buying', 'buys',
                     'purchase', 'purchased',
                     'takeover', 'tookover',
                     'merger', 'merge', 'merged', "buyout", "'s",
                     'parent', 'owner', 'subsidiary', 'owned', "absorption", "absorbed"]
                     

acquired_bigrams = ['take over', 'took over', 'takes over', 'takeover of',
                    'to buy', 'is buying', 'to purchase', 'purchase of',
                    'merge with', 'merged with',
                    'bid for', 'owned by', "parent company"]
                    

acquired_trigrams = ["'s purchase of", "agreed to purchase", "agree to purchase", "agree to buy", "agreed to buy",
                     "agreed to merge", "agree to merge", "a unit of", "parent company of", 'the parent of']


headquarters_unigrams = ['headquarters', 'headquartered', 'headquarter', 'offices', 'office',
                         'building', 'buildings', 'factory', 'plant', 'compound', 'situated', 'HQ' ]

headquarters_bigrams = ['based in', 'located in', 'main office', 'main offices',
                        'offices in', 'building in','office in', 'branch in',
                        'store in', 'firm in', 'factory in', 'plant in',
                        'head office', 'head offices', 'in central',
                        'in downtown', 'outskirts of', 'suburbs of', 'situated in', 'campus in']
headquarters_trigrams = []

employment_unigrams = ['chief', 'scientist', 'professor', 'biologist', 'ceo',
                       'CEO', 'employer', 'president', 'chairman', 'executive', 'officer',
                       'director', 'researcher', 'dean', 'fellow', 'head', 'economist', 'faculty',
                       'physicist', 'partner',
                       'founder', 'co-founder', 'cofounder', 'co-founded',
                       'cofounded', 'founded', 'founders', 'resign', 'work', 'ambassador', 'ambassadors', 'minister',
                       'diplomat', 'diplomats',
                       'boss', 'Chairman', 'manager', 'mogul', 'Manager', 'owner', 'coach',
                       'President', 'Chief', 'Coach', 'Secretary', 'secretary', 'principal',
                       'administrator', 'GM', 'Director', 'leader', 'assistant',
                       'governor', 'Minister', 'commissioner', 'fired',
                       'goalkeeper', 'player', 'players', 'commander', 'keeper', 'analyst', 'captain', 'driver', 'admiral', "'s"]

employment_bigrams = ['work for', 'retire from', 'associated with', 'resigns from', 'partner of', 'lead at',
                      'worked for', 'retired from',  'resigned from',
                      'controlled by', 'Secretary-General', 'headed by', 'members of', 'led by', "'s behalf"]
employment_trigrams = []

bad_tokens = [",", "(", ")", ";", "''",  "``", "'s", "-", "vs.", "v", "'", ":",
              ".", "--"]
			  
stopwords_list = stopwords.words('english')
not_valid = bad_tokens + stopwords_list

# threshold for the proximity PMI: a fact is accepted when the ratio of
# matching sentences that contain a relational word to those that do not
# reaches this value
PMI = 0.5

# Parameters for relationship extraction from Sentence
MAX_TOKENS_AWAY = 6
MIN_TOKENS_AWAY = 1
CONTEXT_WINDOW = 2
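# NOTE (assumption, not documented here): these values are passed straight to
# the Sentence class; MAX/MIN_TOKENS_AWAY presumably bound the number of tokens
# allowed between the two entities, and CONTEXT_WINDOW the number of context
# tokens kept before and after them.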

# DEBUG stuff
PRINT_NOT_FOUND = False

# stores all variations matched with database
manager = multiprocessing.Manager()
all_in_database = manager.dict()


class ExtractedFact(object):
    def __init__(self, _e1, _e2, _score, _bef, _bet, _aft, _sentence,
                 _passive_voice, _rel, _tid=None):
        self.e1 = _e1
        self.e2 = _e2
        self.score = _score
        self.bef_words = _bef
        self.bet_words = _bet
        self.aft_words = _aft
        self.sentence = _sentence
        self.passive_voice = _passive_voice
        self.rel = _rel
        self.tid=_tid

    def __cmp__(self, other):
        if other.score > self.score:
            return -1
        elif other.score < self.score:
            return 1
        else:
            return 0

    def __hash__(self):
        sig = hash(self.e1) ^ hash(self.e2) ^ hash(self.bef_words) ^ \
              hash(self.bet_words) ^ hash(self.aft_words) ^ \
              hash(self.score) ^ hash(self.sentence) ^ hash(self.rel) ^ hash(self.tid)
        return sig

    def __eq__(self, other):
        if self.e1 == other.e1 and \
           self.e2 == other.e2 and \
           self.score == other.score and \
           self.bef_words == other.bef_words and \
           self.bet_words == other.bet_words and \
           self.aft_words == other.aft_words and \
           self.sentence == other.sentence and \
           self.tid == other.tid:
            return True
        else:
            return False


# ###########################################
# Misc., Utils, parsing corpus into memory #
# ###########################################

def timecall(f):
    @functools.wraps(f)
    def wrapper(*args, **kw):
        start = time.time()
        result = f(*args, **kw)
        end = time.time()        
        print "Time taken: %.2f seconds" % (end - start)
        return result

    return wrapper


def is_acronym(entity):
    if len(entity.split()) == 1 and entity.isupper():
        return True
    else:
        return False
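
# illustrative examples (not part of the original code):
#   is_acronym("IBM")            -> True
#   is_acronym("United Nations") -> False (more than one token)
#   is_acronym("Xerox")          -> False (not all upper-case)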


def process_corpus(queue, g_dash, e1_type, e2_type):
    count = 0
    added = 0
    while True:
        try:
            if count % 25000 == 0:
                print multiprocessing.current_process(), \
                    "In Queue", queue.qsize(), "Total added: ", added
            line = queue.get_nowait()
            s = Sentence(line.strip(), e1_type, e2_type, MAX_TOKENS_AWAY,
                         MIN_TOKENS_AWAY, CONTEXT_WINDOW)
            for r in s.relationships:
                tokens = word_tokenize(str(" ".join([x[0] for x in r.between])).strip())
                if all(x in not_valid for x in tokens):
                    continue
                elif "," in tokens and tokens[0] != ',':
                    continue
                else:
                    r.between = str(" ".join([x[0] for x in r.between])).strip()
                    r.before = str(" ".join([x[0] for x in r.before])).strip()
                    r.after = str(" ".join([x[0] for x in r.after])).strip()
                    g_dash.append(r)
                    added += 1
            count += 1
        except Queue.Empty:
            break


def process_output(data, threshold, rel_type):
    """
    parses the file with the relationships extracted by the system
    each relationship is transformed into a ExtracteFact class
    """
    system_output = list()
    for line in fileinput.input(data):        
        if line.startswith('instance'):
            instance_parts, score = line.split("score:")
            e1, e2 = instance_parts.split("instance:")[1].strip().split('\t')

        if line.startswith('sentence'):
            sentence = line.split("sentence:")[1].strip()

        if line.startswith('pattern_bef:'):
            bef = line.split("pattern_bef:")[1].strip()

        if line.startswith('pattern_bet:'):
            bet = line.split("pattern_bet:")[1].strip()

        if line.startswith('pattern_aft:'):
            aft = line.split("pattern_aft:")[1].strip()

        if line.startswith('passive voice:'):
            tmp = line.split("passive voice:")[1].strip()
            if tmp == 'False':
                passive_voice = False
            elif tmp == 'True':
                passive_voice = True

        if line.startswith('\n') and float(score) >= threshold:
            if 'bef' not in locals():
                bef = ''
            if 'aft' not in locals():
                aft = ''
            if passive_voice is True and rel_type in ['acquired',
                                                      'headquarters']:
                r = ExtractedFact(e2, e1, float(score), bef, bet, aft,
                                  sentence, passive_voice, rel_type)
            else:
                r = ExtractedFact(e1, e2, float(score), bef, bet, aft,
                                  sentence, passive_voice, rel_type)
            
            if ("'s parent" in bet or 'subsidiary of' in bet or
                bet == 'subsidiary') and rel_type == 'acquired':
                r = ExtractedFact(e2, e1, float(score), bef, bet, aft,
                                  sentence, passive_voice, rel_type)
            system_output.append(r)

    fileinput.close()
    return system_output


def process_freebase(data, rel_type):
    # Load relationships from Freebase and keep them in the same direction as
    # the output of the extraction system
    """
    # rel_type                   Gold standard directions
    founder_arg2_arg1            PER-ORG
    headquarters_arg1_arg2       ORG-LOC
    acquired_arg1_arg2           ORG-ORG
    contained_by_arg1_arg2       LOC-LOC
    employer_arg2_arg1           PER-ORG
    """

    # store a tuple (entity1, entity2) in a dictionary
    database_1 = defaultdict(list)

    # store in a dictionary per relationship: dict['ent1'] = 'ent2'
    database_2 = defaultdict(list)

    # store in a dictionary per relationship: dict['ent2'] = 'ent1'
    database_3 = defaultdict(list)

    # regex used to clean entities
    numbered = re.compile(r'#[0-9]+$')

    # for the 'founder' relationship, skip these Freebase entries, since
    # Freebase lists organisations/countries (and not persons) as their founders
    founder_to_ignore = ['UNESCO', 'World Trade Organization', 'European Union',
                         'United Nations']

    # print('number of tuples in Freebase:', len(fileinput.input(data)))
    # count = 0
    for line in fileinput.input(data):
        if line.startswith('#'):
            continue
        try:
            e1, r, e2 = line.split('\t')[:3]
        except Exception:
            print line
            print line.split('\t')
            sys.exit()

        # ignore some entities, which are Freebase identifiers or are ambiguous
        if e1.startswith('/') or e2.startswith('/'):
            continue
        if e1.startswith('m/') or e2.startswith('m/'):
            continue
        if re.search(numbered, e1) or re.search(numbered, e2):
            continue
        if e2.strip() in founder_to_ignore:
            continue
        else:
            if "(" in e1:
                e1 = re.sub(r"\(.*\)", "", e1).strip()
            if "(" in e2:
                e2 = re.sub(r"\(.*\)", "", e2).strip()

            if rel_type == 'founder' or rel_type == 'employer':
                database_1[(e2.strip(), e1.strip())].append(r)
                database_2[e2.strip()].append(e1.strip())
                database_3[e1.strip()].append(e2.strip())
            else:
                database_1[(e1.strip(), e2.strip())].append(r)
                database_2[e1.strip()].append(e2.strip())
                database_3[e2.strip()].append(e1.strip())
       
    return database_1, database_2, database_3


def load_acronyms(data):
    acronyms = defaultdict(list)
    for line in fileinput.input(data):
        parts = line.split('\t')
        acronym = parts[0].strip()
        if "/" in acronym:
            continue
        expanded = parts[-1].strip()
        if "/" in expanded:
            continue
        acronyms[acronym].append(expanded)
    fileinput.close()
    return acronyms


def load_dbpedia(data, database_1, database_2):
    for line in fileinput.input(data):
        e1, rel, e2, p = line.split()
        e1 = e1.split('<http://dbpedia.org/resource/')[1].replace(">", "")
        e2 = e2.split('<http://dbpedia.org/resource/')[1].replace(">", "")
        e1 = re.sub("_", " ", e1)
        e2 = re.sub("_", " ", e2)

        if "(" in e1 or "(" in e2:
            e1 = re.sub(r"\(.*\)", "", e1)
            e2 = re.sub(r"\(.*\)", "", e2)

            # store a tuple (entity1, entity2) in a dictionary
            database_1[(e1.strip(), e2.strip())].append(p)

            # store in a dictionary per relationship: dict['ent1'] = 'ent2'
            database_2[e1.strip()].append(e2.strip())

        else:
            e1 = e1.decode("utf8").strip()
            e2 = e2.decode("utf8").strip()
            # store a tuple (entity1, entity2) in a dictionary
            database_1[(e1, e2)].append(p)

            # store in a dictionary per relationship: dict['ent1'] = 'ent2'
            database_2[e1.strip()].append(e2.strip())

    fileinput.close()

    return database_1, database_2


def extract_bigrams(text):
    tokens = word_tokenize(text)
    return [gram[0]+' '+gram[1] for gram in bigrams(tokens)]
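
# illustrative example (assuming NLTK's default tokenizer):
#   extract_bigrams("was acquired by Google")
#   -> ['was acquired', 'acquired by', 'by Google']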


def extract_trigrams(text):
    tokens = word_tokenize(text)
    return [gram[0]+' '+gram[1]+' '+gram[2] for gram in ngrams(tokens, 3)]
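
# illustrative example (assuming NLTK's default tokenizer):
#   extract_trigrams("agreed to buy Motorola")
#   -> ['agreed to buy', 'to buy Motorola']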

# ########################################
# Estimations of sets and intersections #
# ########################################
@timecall
def calculate_a(not_in_database, e1_type, e2_type, index, rel_words_unigrams,
                rel_words_bigrams, rel_words_trigrams, system_output_dir):
    m = multiprocessing.Manager()
    queue = m.Queue()
    # num_cpus = multiprocessing.cpu_count()
    num_cpus = NUM_CPUS
    results = [m.list() for _ in range(num_cpus)]
    not_found = [m.list() for _ in range(num_cpus)]

    for r in not_in_database:
        queue.put(r)

    processes = [multiprocessing.Process(
        target=proximity_pmi_a,
        args=(e1_type, e2_type, queue, index, results[i], not_found[i],
              rel_words_unigrams, rel_words_bigrams, rel_words_trigrams)) for i in range(num_cpus)]

    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    a = list()
    for l in results:
        a.extend(l)

    wrong = list()
    for l in not_found:
        wrong.extend(l)

    return a, wrong


@timecall
def calculate_b(output, database_1, database_2, database_3, e1_type, e2_type, system_output_dir):
    # intersection between the system output and the database
    # it is assumed that every fact in this region is correct
    m = multiprocessing.Manager()
    queue = m.Queue()
    # num_cpus = multiprocessing.cpu_count()
    num_cpus = NUM_CPUS
    results = [m.list() for _ in range(num_cpus)]
    no_matches = [m.list() for _ in range(num_cpus)]

    for r in output:
        queue.put(r)

    processes = [multiprocessing.Process(
        target=string_matching_parallel,
        args=(results[i], no_matches[i], database_1, database_2, database_3,
              queue, e1_type, e2_type))
                 for i in range(num_cpus)]

    for proc in processes:
        proc.start()

    for proc in processes:
        proc.join()

    b = set()
    for l in results:
        b.update(l)

    not_found = set()
    for l in no_matches:
        not_found.update(l)

    return b, not_found


@timecall
def calculate_c(corpus, database_1, database_2, database_3, b, e1_type, e2_type,
                rel_type, rel_words_unigrams, rel_words_bigrams, system_output_dir):

    # contains the database facts described in the corpus
    # but not extracted by the system
    #
    # G' = superset of G, cartesian product of all possible entities and
    # relations (i.e., G' = E x R x E)
    # for now, all relationships from a sentence
    print "Building G', a superset of G"
    m = multiprocessing.Manager()
    queue = m.Queue()
    g_dash = m.list()
    # num_cpus = multiprocessing.cpu_count()
    num_cpus = NUM_CPUS
    # check if superset G' for e1_type, e2_type already exists and
    # if G' minus KB for rel_type exists

    # if it exists load into g_dash_set
    if os.path.isfile("superset_" + e1_type + "_" + e2_type + ".pkl"):
        f = open("superset_" + e1_type + "_" + e2_type + ".pkl")
        print "\nLoading superset G'", "superset_" + e1_type + "_" + \
                                       e2_type + ".pkl"
        g_dash_set = cPickle.load(f)
        f.close()

    # else generate G' and G minus D
    else:
        with open(corpus) as f:
            data = f.readlines()
            count = 0
            print "Storing in shared Queue"
            for l in data:
                if count % 50000 == 0:
                    sys.stdout.write(".")
                    sys.stdout.flush()
                queue.put(l)
                count += 1
        print "\nQueue size:", queue.qsize()

        processes = [multiprocessing.Process(
            target=process_corpus,
            args=(queue, g_dash, e1_type, e2_type))
                     for _ in range(num_cpus)]

        print "Extracting all possible " + e1_type + "," + e2_type + \
              " relationships from the corpus"
        print "Running", len(processes), "threads"

        for proc in processes:
            proc.start()

        for proc in processes:
            proc.join()

        print("relationships built", g_dash)
        print len(g_dash), "relationships built"
        g_dash_set = set(g_dash)
        print len(g_dash_set), "unique relationships"
        print "Dumping into file", "superset_" + e1_type + "_" + e2_type + ".pkl"
        f = open("superset_" + e1_type + "_" + e2_type + ".pkl", "wb")
        cPickle.dump(g_dash_set, f)
        f.close()

    # Estimate G intersected with D: look for facts in G' that match a fact in the database
    # check if already exists for this particular relationship
    if os.path.isfile(os.path.join(system_output_dir, rel_type + "_g_intersection_d.pkl")) and \
            os.path.isfile(os.path.join(system_output_dir, rel_type + "_g_minus_d.pkl")):
        f = open(os.path.join(system_output_dir, rel_type + "_g_intersection_d.pkl"), "rb")
        print "\nLoading G intersected with D", os.path.join(system_output_dir, rel_type + "_g_intersection_d.pkl")
        g_intersect_d = cPickle.load(f)
        f.close()

        f = open(os.path.join(system_output_dir, rel_type + "_g_minus_d.pkl"), "rb")
        print "\nLoading superset G' minus D", os.path.join(system_output_dir, rel_type + "_g_minus_d.pkl")
        g_minus_d = cPickle.load(f)
        f.close()

    else:
        print "Estimating G intersection with D"
        g_intersect_d = set()
        print "G':", len(g_dash_set)
        print "Database:", len(database_1.keys())

        # Facts not in the database, to use in estimating set d
        g_minus_d = set()

        queue = manager.Queue()
        results = [manager.list() for _ in range(num_cpus)]
        no_matches = [manager.list() for _ in range(num_cpus)]

        # Load everything into a shared queue
        for r in g_dash_set:
            queue.put(r)

        processes = [multiprocessing.Process(
            target=string_matching_parallel,
            args=(results[i], no_matches[i],
                  database_1, database_2, database_3, queue, e1_type, e2_type))
                     for i in range(num_cpus)]

        for proc in processes:
            proc.start()

        for proc in processes:
            proc.join()

        for l in results:
            g_intersect_d.update(l)

        for l in no_matches:
            g_minus_d.update(l)

        print "Extra filtering: from the intersection of G' with D, " \
              "select only those based on keywords"
        print len(g_intersect_d)
        filtered = set()
        for r in g_intersect_d:
            # print('r.between:', r.between)
            unigrams_bet = word_tokenize(r.between)
            unigrams_bef = word_tokenize(r.before)
            unigrams_aft = word_tokenize(r.after)
            bigrams_bet = extract_bigrams(r.between)
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_bet):
                filtered.add(r)
                continue
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_bef):
                filtered.add(r)
                continue
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_aft):
                filtered.add(r)
                continue
            elif any(x in rel_words_bigrams for x in bigrams_bet):
                filtered.add(r)
                continue
        g_intersect_d = filtered
        print len(g_intersect_d), "relationships in the corpus " \
                                  "which are in the KB"
        if len(g_intersect_d) > 0:
            # dump G intersected with D to file
            f = open(os.path.join(system_output_dir, rel_type + "_g_intersection_d.pkl"), "wb")
            cPickle.dump(g_intersect_d, f)
            f.close()

        print "Extra filtering: from the G' not in D, select only " \
              "those based on keywords"
        filtered = set()
        for r in g_minus_d:
            # print('r.between:', r.between)
            unigrams_bet = word_tokenize(r.between)
            unigrams_bef = word_tokenize(r.before)
            unigrams_aft = word_tokenize(r.after)
            bigrams_bet = extract_bigrams(r.between)
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_bet):
                filtered.add(r)
                continue
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_bef):
                filtered.add(r)
                continue
            if any(str(x).lower() in rel_words_unigrams for x in unigrams_aft):
                filtered.add(r)
                continue
            elif any(x in rel_words_bigrams for x in bigrams_bet):
                filtered.add(r)
                continue
        g_minus_d = filtered
        print len(g_minus_d), "relationships in the corpus not in the KB"
        if len(g_minus_d) > 0:
            # dump G - D to file, relationships in the corpus not in KB
            f = open(os.path.join(system_output_dir, rel_type + "_g_minus_d.pkl"), "wb")
            cPickle.dump(g_minus_d, f)
            f.close()

    # having B and G_intersect_D => |c| = |G_intersect_D| - |b|
    c = g_intersect_d.difference(set(b))
    assert len(g_minus_d) > 0
    return c, g_minus_d


@timecall
def calculate_d(g_minus_d, a, e1_type, e2_type, index, rel_type,
                rel_words_unigrams, rel_words_bigrams, rel_words_trigrams, system_output_dir):

    # contains facts described in the corpus that are not
    # in the system output nor in the database
    #
    # by applying the PMI to the facts not in the database (i.e., G' minus D)
    # we determine |G \ D|, then we can estimate |d| = |G \ D| - |a|
    #
    # |G' \ D|
    # determine facts not in the database, with high PMI, that is,
    # facts that are true and are not in the database

    # check if it was already calculated and stored in disk
    if os.path.isfile(os.path.join(system_output_dir, rel_type + "_high_pmi_not_in_database.pkl")):
        f = open(os.path.join(system_output_dir, rel_type + "_high_pmi_not_in_database.pkl"), "rb")
        print "\nLoading high PMI facts not in the database", \
            os.path.join(system_output_dir, rel_type + "_high_pmi_not_in_database.pkl")
        g_minus_d = cPickle.load(f)
        f.close()

    else:
        m = multiprocessing.Manager()
        queue = m.Queue()
        # num_cpus = multiprocessing.cpu_count()
        num_cpus = NUM_CPUS
        results = [m.list() for _ in range(num_cpus)]

        for r in g_minus_d:
            queue.put(r)

        # calculate PMI for r not in database
        processes = [multiprocessing.Process(
            target=proximity_pmi_rel_word,
            args=(e1_type, e2_type, queue, index,
                  results[i], rel_words_unigrams, rel_words_bigrams, rel_words_trigrams))
                     for i in range(num_cpus)]

        for proc in processes:
            proc.start()

        for proc in processes:
            proc.join()

        g_minus_d = set()
        for l in results:
            g_minus_d.update(l)

        print "High PMI facts not in the database", len(g_minus_d)

        # dump high PMI facts not in the database
        if len(g_minus_d) > 0:
            f = open(os.path.join(system_output_dir, rel_type + "_high_pmi_not_in_database.pkl"), "wb")
            print "Dumping high PMI facts not in the database to", \
                os.path.join(system_output_dir, rel_type + "_high_pmi_not_in_database.pkl")
            cPickle.dump(g_minus_d, f)
            f.close()

    return g_minus_d.difference(a)


########################################################################
# Parallelized functions: each function will run as a different process #
########################################################################
def proximity_pmi_rel_word(e1_type, e2_type, queue, index, results,
                           rel_words_unigrams, rel_words_bigrams, rel_words_trigrams):
    idx = open_dir(index)
    count = 0
    distance = MAX_TOKENS_AWAY
    q_limit = 500
    with idx.searcher() as searcher:
        while True:
            try:
                r = queue.get_nowait()
                if count % 50 == 0:
                    print "\n", multiprocessing.current_process(), \
                        "In Queue", queue.qsize(), \
                        "Total Matched: ", len(results)
                if (r.e1, r.e2) not in all_in_database:
                    # if it's not in the database, calculate the proximity PMI
                    entity1 = "<" + e1_type + ">" + r.e1 + "</" + e1_type + ">"
                    entity2 = "<" + e2_type + ">" + r.e2 + "</" + e2_type + ">"
                    t1 = query.Term('sentence', entity1)
                    t3 = query.Term('sentence', entity2)

                    # Entities proximity query without relational words
                    q1 = spans.SpanNear2(
                        [t1, t3], slop=distance,
                        ordered=True, mindist=1)
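                    # SpanNear2 matches sentences where the two entity terms
                    # occur within `slop` (= MAX_TOKENS_AWAY) tokens of each
                    # other, in this order and at least `mindist` tokens apart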
                    hits = searcher.search(q1, limit=q_limit)

                    # Entities proximity considering relational words
                    # From the results above count how many contain a
                    # valid relational word

                    hits_with_r = 0
                    hits_without_r = 0
                    for s in hits:
                        sentence = s.get("sentence")
                        s = Sentence(sentence, e1_type, e2_type,
                                     MAX_TOKENS_AWAY, MIN_TOKENS_AWAY,
                                     CONTEXT_WINDOW)

                        for s_r in s.relationships:
                            if r.e1.decode("utf8") == s_r.e1 and \
                                            r.e2.decode("utf8") == s_r.e2:

                                # unigrams_rel_words = word_tokenize(s_r.between)
                                # bigrams_rel_words = extract_bigrams(s_r.between)

                                bet_text = str(" ".join([x[0] for x in s_r.between])).strip()
                                unigrams_rel_words = word_tokenize(bet_text)
                                bigrams_rel_words = extract_bigrams(bet_text)
                                trigrams_rel_words = extract_trigrams(bet_text)

                                if all(x in not_valid
                                       for x in unigrams_rel_words):
                                    hits_without_r += 1
                                    continue
                                elif any(str(x).lower() in rel_words_unigrams for x in
                                         unigrams_rel_words):

                                    hits_with_r += 1

                                elif any(x in rel_words_bigrams
                                         for x in bigrams_rel_words):
                                    hits_with_r += 1

                                elif any(x in rel_words_trigrams
                                         for x in trigrams_rel_words):
                                    hits_with_r += 1
                                else:
                                    hits_without_r += 1

                    if hits_with_r > 0 and hits_without_r > 0:
                        pmi = float(hits_with_r) / float(hits_without_r)
                        if pmi >= PMI:
                            if len(word_tokenize(str(" ".join([x[0] for x in s_r.between])).strip())) > 0:
                                if word_tokenize(str(" ".join([x[0] for x in s_r.between])).strip())[-1] == 'by':
                                    tmp = s_r.e2
                                    s_r.e2 = s_r.e1
                                    s_r.e1 = tmp                                   
                                
                            results.append(r)
                    
                    elif hits_with_r > 0 and hits_without_r == 0:
                        if len(word_tokenize(str(" ".join([x[0] for x in s_r.between])).strip())) > 0:
                            if word_tokenize(str(" ".join([x[0] for x in s_r.between])).strip())[-1] == 'by':
                                tmp = s_r.e2
                                s_r.e2 = s_r.e1
                                s_r.e1 = tmp
                                
                        results.append(r)

                count += 1
            except Queue.Empty:
                break


def string_matching_parallel(matches, no_matches, database_1, database_2,
                             database_3, queue, e1_type, e2_type):
    count = 0
    while True:
        try:
            r = queue.get_nowait()
            found = False
            count += 1
            if count % 500 == 0:
                print multiprocessing.current_process(), \
                    "In Queue", queue.qsize()

            # check if it's in the cache, i.e., if the tuple was already matched
            if (r.e1, r.e2) in all_in_database:
                matches.append(r)
                found = True

            # check for a relationship with a direct string matching
            if found is False:
                if len(database_1[(r.e1.decode("utf8"),
                                   r.e2.decode("utf8"))]) > 0:
                    matches.append(r)
                    all_in_database[(r.e1, r.e2)] = "Found"
                    found = True

            if found is False:
                # database_2: arg_1 rel list(arg_2)
                # check for a direct string matching with all possible arg2
                # FOUNDER   : r.ent1:ORG   r.ent2:PER
                # DATABASE_1: (ORG,PER)
                # DATABASE_2: ORG   list<PER>
                # DATABASE_3: PER   list<ORG>

                e2 = database_2[r.e1.decode("utf8")]
                if len(e2) > 0:
                    if r.e2 in e2:
                        matches.append(r)
                        all_in_database[(r.e1, r.e2)] = "Found"
                        found = True

            # if no exact match was found, look up all arg_1 entities
            # associated with r.e2 (database_3) and compare them against r.e1
            # using Jaccard / Jaro-Winkler similarity
            if found is False:
                arg1_list = database_3[r.e2]
                if arg1_list is not None:
                    for arg1 in arg1_list:
                        if e1_type == 'ORG':
                            new_arg1 = re.sub(r" Corporation| Inc\.", "", arg1)
                        else:
                            new_arg1 = arg1

                        # Jaccard similarity
                        set_1 = set(new_arg1.split())
                        set_2 = set(r.e1.split())

                        jaccardi = \
                            float(len(set_1.intersection(set_2))) / \
                            float(len(set_1.union(set_2)))
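                        # e.g. comparing "Google Books" with "Google":
                        # |{Google}| / |{Google, Books}| = 0.5 (illustrative)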

                        if jaccardi >= 0.5:
                            matches.append(r)
                            all_in_database[(r.e1, r.e2)] = "Found"
                            found = True

                        # Jaro Winkler
                        elif jaccardi <= 0.5:
                            score = jellyfish.jaro_winkler(
                                # unicode(new_arg1.upper()), unicode(r.e1.upper())
                                new_arg1.upper().decode("utf8"), r.e1.upper().decode("utf8")
                            )
                            if score >= 0.9:
                                matches.append(r)
                                all_in_database[(r.e1, r.e2)] = "Found"
                                found = True

            # likewise, look up all arg_2 entities associated with r.e1
            # (database_2) and compare them against r.e2
            # using Jaccard / Jaro-Winkler similarity
            if found is False:
                arg2_list = database_2[r.e1]
                if arg2_list is not None:
                    for arg2 in arg2_list:
                        # Jaccard similarity
                        if e1_type == 'ORG':
                            new_arg2 = re.sub(r" Corporation| Inc\.", "", arg2)
                        else:
                            new_arg2 = arg2
                        set_1 = set(new_arg2.split())
                        set_2 = set(r.e2.split())
                        jaccardi = \
                            float(len(set_1.intersection(set_2))) / \
                            float(len(set_1.union(set_2)))

                        if jaccardi >= 0.5:
                            matches.append(r)
                            all_in_database[(r.e1, r.e2)] = "Found"
                            found = True

                        # Jaro Winkler
                        elif jaccardi <= 0.5:
                            score = jellyfish.jaro_winkler(
                                new_arg2.upper().decode("utf8"), r.e2.upper().decode("utf8")
                            )
                            if score >= 0.9:
                                matches.append(r)
                                all_in_database[(r.e1, r.e2)] = "Found"
                                found = True

            if found is False:
                no_matches.append(r)
                if PRINT_NOT_FOUND is True:
                    print r.e1, '\t', r.e2

        except Queue.Empty:
            break


def proximity_pmi_a(e1_type, e2_type, queue, index, results, not_found,
                    rel_words_unigrams, rel_words_bigrams, rel_words_trigrams):
    idx = open_dir(index)
    count = 0
    q_limit = 500
    with idx.searcher() as searcher:
        while True:
            try:
                r = queue.get_nowait()
                count += 1
                if count % 50 == 0:
                    print multiprocessing.current_process(), \
                        "To Process", queue.qsize(), \
                        "Correct found:", len(results)

                # the fact is not in the database, so estimate the proximity PMI
                entity1 = "<" + e1_type + ">" + r.e1 + "</" + e1_type + ">"
                entity2 = "<" + e2_type + ">" + r.e2 + "</" + e2_type + ">"
                t1 = query.Term('sentence', entity1)
                t3 = query.Term('sentence', entity2)

                # First count the proximity (MAX_TOKENS_AWAY) occurrences
                # of entities r.e1 and r.e2
                q1 = spans.SpanNear2([t1, t3],
                                     slop=MAX_TOKENS_AWAY,
                                     ordered=True,
                                     mindist=1)
                hits = searcher.search(q1, limit=q_limit)                

                # Entities proximity considering relational words
                # From the results above count how many contain a
                # valid relational word
                hits_with_r = 0
                hits_without_r = 0
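                # NOTE: the "PMI" computed below is really the ratio
                # hits_with_r / hits_without_r; a fact is accepted when this
                # ratio reaches the PMI threshold (or a keyword fallback fires)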

                fact_bet_words_tokens = word_tokenize(r.bet_words)                

                for s in hits:
                    sentence = s.get("sentence")
                    s = Sentence(sentence, e1_type, e2_type, MAX_TOKENS_AWAY,
                                 MIN_TOKENS_AWAY, CONTEXT_WINDOW)
                    for s_r in s.relationships:
                        if r.e1.decode("utf8").lower() == s_r.e1.lower() and\
                                        r.e2.decode("utf8").lower() == s_r.e2.lower():                            
                            bef_text = str(" ".join([x[0] for x in s_r.before])).strip()
                            bet_text = str(" ".join([x[0] for x in s_r.between])).strip()
                            aft_text = str(" ".join([x[0] for x in s_r.after])).strip()
                            unigrams_bef_words = word_tokenize(bef_text)
                            unigrams_bet_words = word_tokenize(bet_text)
                            unigrams_aft_words = word_tokenize(aft_text)
                            bigrams_rel_words = extract_bigrams(bet_text)
                            trigrams_rel_words = extract_trigrams(bet_text)

                            if fact_bet_words_tokens == unigrams_bet_words:
                                hits_with_r += 1

                            elif any(str(x).lower() in rel_words_unigrams for x in unigrams_bef_words):
                                hits_with_r += 1

                            elif any(str(x).lower() in rel_words_unigrams
                                     for x in unigrams_bet_words):
                                hits_with_r += 1

                            elif any(str(x).lower() in rel_words_unigrams
                                     for x in unigrams_aft_words):
                                hits_with_r += 1

                            elif any(x in rel_words_bigrams for x in bigrams_rel_words):
                                hits_with_r += 1

                            elif any(x in rel_words_trigrams for x in trigrams_rel_words):
                                hits_with_r += 1

                            elif rel_words_bigrams == bigrams_rel_words:
                                hits_with_r += 1
                            else:
                                hits_without_r += 1

                if hits_with_r > 0 and hits_without_r > 0:
                    pmi = float(hits_with_r) / float(hits_without_r)
                    if pmi >= PMI:
                        results.append(r)
                    else:
                        if str(r.bet_words).strip() in rel_words_bigrams:
                            results.append(r)
                        elif any(x in rel_words_bigrams for x in extract_bigrams(str(r.bet_words).lower().strip())):
                            results.append(r)
                        elif any(str(x).lower() in rel_words_unigrams for x in word_tokenize(r.bet_words)):
                            results.append(r)
                        elif str(r.bet_words).strip() in rel_words_trigrams or str(r.bet_words).strip() == "'s purchase of":
                            results.append(r)
                        elif str(r.bef_words).strip() in rel_words_bigrams:
                            results.append(r)
                        elif any(str(x).lower() in rel_words_unigrams for x in word_tokenize(r.bef_words)):
                            results.append(r)
                        elif str(r.bef_words).strip() in rel_words_trigrams:
                            results.append(r)
                        elif str(r.aft_words).strip() in rel_words_bigrams:
                            results.append(r)
                        elif any(str(x).lower() in rel_words_unigrams for x in word_tokenize(r.aft_words)):
                            results.append(r)
                        elif str(r.aft_words).strip() in rel_words_trigrams:
                            results.append(r)
                        else:
                            print('pmi=%s r.e1=%s r.e2=%s r.bet_words:%s' % (pmi, r.e1, r.e2, r.bet_words))
                            print('hits_with_r=%s hits_without_r=%s' % (hits_with_r, hits_without_r))
                            not_found.append(r)

                elif hits_with_r > 0 and hits_without_r == 0:
                    results.append(r)
                else:
                    if len(hits) == 0:
                        
                        if any(str(x).lower() in rel_words_unigrams for x in word_tokenize(str(r.bet_words).strip())):
                            results.append(r)
                        else:
                            not_found.append(r)
                    else:
                        if str(r.bet_words).strip() in rel_words_bigrams:
                            results.append(r)
                        elif any(str(x).lower() in rel_words_unigrams for x in word_tokenize(r.bet_words)):
                            results.append(r)
                        elif str(r.bet_words).strip() in rel_words_trigrams or str(r.bet_words).strip() == "'s purchase of":
                            results.append(r)
                        else:
                            print('Hits r.e1=%s r.e2=%s r.bet_words:%s' % (r.e1, r.e2, r.bet_words))
                            not_found.append(r)
                count += 1

            except Queue.Empty:
                break


def main():
    # "Automatic Evaluation of Relation Extraction Systems on Large-scale"
    # https://akbcwekex2012.files.wordpress.com/2012/05/8_paper.pdf
    #
    # S  - system output
    # D  - database (freebase)
    # G  - will be the resulting ground truth
    # G' - superset of G; contains both true facts and wrong facts
    # a  - contains correct facts from the system output
    #
    # b  - intersection between the system output and the
    #      database (i.e., freebase),
    #      it is assumed that every fact in this region is correct
    # c  - contains the database facts described in the corpus
    #      but not extracted by the system
    # d  - contains the facts described in the corpus that are not
    #      in the system output nor in the database
    #
    # Precision = (|a| + |b|) / |S|
    # Recall    = (|a| + |b|) / (|a| + |b| + |c| + |d|)
    # F1        = 2*P*R / (P+R)
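    #
    # Worked example with hypothetical counts: |a|=40, |b|=60, |c|=30, |d|=70,
    # |S|=120  =>  P = 100/120 ~ 0.83, R = 100/200 = 0.50,
    # F1 = 2*0.83*0.50/(0.83+0.50) ~ 0.62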

    if len(sys.argv) == 1:
        print "No arguments"
        print "Use: large_scale_evaluation_freebase.py threshold system_output rel_type database root_dir corpus index"
        # e.g.:
        # python large_scale_evaluation_freebase.py 0.5 ../../data/output/BREE/REL_ACQUIRED_ORG_ORG/relationships_baseline.txt
        # acquired ../../resources/freebase-easy-14-04-14/freebase_facts.txt ../../data/ ../../data/input/sentences.txt ./index_full
        print "\n"
        sys.exit(0)

    threshold = float(sys.argv[1])
    rel_type = sys.argv[3]
    # root_dir = sys.argv[5]
    # root_dir = sys.argv[6]
    system_output_dir = os.path.dirname(sys.argv[2])

    # load relationships extracted by the system
    system_output = process_output(sys.argv[2], threshold, rel_type)
    print "Relationships score threshold :", threshold
    print "System output relationships   :", len(system_output)


    if os.path.isfile(os.path.join(os.path.dirname(sys.argv[4]), rel_type+"_"+"db1.pkl")):
        print('loading the saved Freebase databases.....')

        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type+"_"+"db1.pkl"), "rb")
        database_1 = cPickle.load(f)
        f.close()

        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type + "_" + "db2.pkl"), "rb")
        database_2 = cPickle.load(f)
        f.close()

        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type + "_" + "db3.pkl"), "rb")
        database_3 = cPickle.load(f)
        f.close()

    else:
        # load freebase relationships as the database
        database_1, database_2, database_3 = process_freebase(sys.argv[4], rel_type)
        print "Freebase relationships loaded :", len(database_1.keys())
        # dump the files: for faster access
        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type + "_" + "db1.pkl"), "wb")
        cPickle.dump(database_1, f)
        f.close()

        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type + "_" + "db2.pkl"), "wb")
        cPickle.dump(database_2, f)
        f.close()

        f = open(os.path.join(os.path.dirname(sys.argv[4]), rel_type + "_" + "db3.pkl"), "wb")
        cPickle.dump(database_3, f)
        f.close()

    # corpus from which the system extracted relationships
    
    corpus = sys.argv[6]

    # index used to estimate the proximity PMI
    # NOTE: hard-coded here; the 'index' command-line argument is not read
    index = "./index_full"

    # entity semantic types and relational words for each relationship type
    rel_words_unigrams = None
    rel_words_bigrams = None
    rel_words_trigrams = None

    if rel_type == 'founder':
        e1_type = "ORG"
        e2_type = "PER"
        rel_words_unigrams = founded_unigrams
        rel_words_bigrams = founded_bigrams
        rel_words_trigrams = founded_trigrams

    elif rel_type == 'acquired':
        e1_type = "ORG"
        e2_type = "ORG"
        rel_words_unigrams = acquired_unigrams
        rel_words_bigrams = acquired_bigrams
        rel_words_trigrams = acquired_trigrams

    elif rel_type == 'headquarters':
        # load dbpedia relationships
        # print "Loading extra DBPedia relationships for", rel_type
        # load_dbpedia(sys.argv[5], database_1, database_2)
        e1_type = "ORG"
        e2_type = "LOC"
        rel_words_unigrams = headquarters_unigrams
        rel_words_bigrams = headquarters_bigrams
        rel_words_trigrams = headquarters_trigrams

    elif rel_type == 'contained_by':
        e1_type = "LOC"
        e2_type = "LOC"

    elif rel_type == 'employer':
        e1_type = "ORG"
        e2_type = "PER"
        rel_words_unigrams = employment_unigrams
        rel_words_bigrams = employment_bigrams
        rel_words_trigrams = employment_trigrams

    else:
        print "Invalid relationship type", rel_type
        print "Use: founder, acquired, headquarters, employer"
        sys.exit(0)

    print "\nRelationship Type:", rel_type
    print "Arg1 Type:", e1_type
    print "Arg2 Type:", e2_type

    print "\nCalculating set B: intersection between system output and database"
    b, not_in_database = calculate_b(system_output, database_1, database_2,
                                     database_3, e1_type, e2_type, system_output_dir=system_output_dir)

    print "System output      :", len(system_output)
    print "Found in database  :", len(b)
    print "Not found          :", len(not_in_database)
    assert len(system_output) == len(not_in_database) + len(b)


    print "\nCalculating set A: correct facts from system output not in " \
          "the database (proximity PMI)"
    a, not_found = calculate_a(not_in_database, e1_type, e2_type, index,
                               rel_words_unigrams, rel_words_bigrams, rel_words_trigrams, system_output_dir=system_output_dir)

    print "System output      :", len(system_output)
    print "Found in database  :", len(b)
    print "Correct in corpus  :", len(a)
    print "Not found          :", len(not_found)
    print "\n"
    assert len(system_output) == len(a) + len(b) + len(not_found)


    # Estimate G intersected with D = |b| + |c|, looking for relationships in
    # G' that match a relationship in D; once we have G intersected with D and
    # |b|, |c| can be derived by: |c| = |G intersected with D| - |b|.
    # G' is a superset of G, the cartesian product of all possible entities
    # and relations (i.e., G' = E x R x E)
    print "\nCalculating set C: database facts in the corpus but not " \
          "extracted by the system"
    c, g_minus_d = calculate_c(corpus, database_1, database_2, database_3, b,
                               e1_type, e2_type, rel_type, rel_words_unigrams,
                               rel_words_bigrams, system_output_dir=system_output_dir)
    assert len(c) > 0

    uniq_c = set()
    for r in c:
        uniq_c.add((r.e1, r.e2))

    # By applying the PMI to the facts not in the database (i.e., G' minus D)
    # we determine |G \ D|, then we can estimate |d| = |G \ D| - |a|
    print "\nCalculating set D: facts described in the corpus not in " \
          "the system output nor in the database"
    d = calculate_d(g_minus_d, a, e1_type, e2_type, index, rel_type,
                    rel_words_unigrams, rel_words_bigrams, rel_words_trigrams, system_output_dir=system_output_dir)

    print "System output      :", len(system_output)
    print "Found in database  :", len(b)
    print "Correct in corpus  :", len(a)
    print "Not found          :", len(not_found)
    print "\n"
    assert len(d) > 0

    uniq_d = set()
    for r in d:
        uniq_d.add((r.e1, r.e2))

    print "|a| =", len(a)
    print "|b| =", len(b)
    print "|c| =", len(c), "(", len(uniq_c), ")"
    print "|d| =", len(d), "(", len(uniq_d), ")"
    print "|S| =", len(system_output)
    print "|G| =", len(set(a).union(set(b).union(set(c).union(set(d)))))
    print "Relationships not found:", len(set(not_found))

    # Write relationships not found in the Database nor with high PMI
    # relational words to disk
    f = open(os.path.join(system_output_dir, rel_type + "_" + sys.argv[2][-11:][:-4] + "_negative.txt"), "w")
    for r in sorted(set(not_found), reverse=True):
        f.write('instance :' + r.e1 + '\t' + r.e2 + '\t' + str(r.score) +
                '\n')
        f.write('sentence :' + r.sentence + '\n')
        f.write('bef_words:' + r.bef_words + '\n')
        f.write('bet_words:' + r.bet_words + '\n')
        f.write('aft_words:' + r.aft_words + '\n')
        f.write('\n')
    f.close()

    # Write all correct relationships (sentence, entities and score) to file
    f = open(os.path.join(system_output_dir, rel_type + "_" + sys.argv[2][-11:][:-4] + "_positive.txt"), "w")
    for r in sorted(set(a).union(b), reverse=True):
        f.write('instance :' + r.e1 + '\t' + r.e2 + '\t' + str(r.score) +
                '\n')
        f.write('sentence :' + r.sentence + '\n')
        f.write('bef_words:' + r.bef_words + '\n')
        f.write('bet_words:' + r.bet_words + '\n')
        f.write('aft_words:' + r.aft_words + '\n')
        f.write('\n')
    f.close()

    a = set(a)
    b = set(b)
    output = set(system_output)
    if len(output) == 0 or len(a) + len(b) == 0:
        print "\nPrecision   : 0.0"
        print "Recall      : 0.0"
        print "F1          : 0.0"
        print "\n"
    else:
        precision = float(len(a) + len(b)) / float(len(output))
        recall = float(len(a) + len(b)) / float(len(a) + len(b) + len(uniq_c) +
                                                len(uniq_d))
        f1 = 2 * (precision * recall) / (precision + recall)
        print "\nPrecision   : ", precision
        print "Recall      : ", recall
        print "F1          : ", f1
        print "\n"

if __name__ == "__main__":
    main()
